import argparse
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
import json
import os
import random
import gc
from typing import Optional
from datasets import load_dataset
from models.language_models import (
    Llama2_7b, Gemma7b, Vicuna13b, Mistral7b,
    PhiMini, Llama3_8b, Qwen7b, Zephyr7bR2D2,
    Mistral7B_RR, Llama3_8bRR
)


def dump_json(data, path: str):
    """
    Dump data as JSON to the specified file path.

    Missing parent directories of ``path`` are created first, so the call
    also works when the output folder does not exist yet.

    Args:
        data: Any object ``json.dump`` can serialize.
        path (str): Destination file; it is overwritten if it already exists.
    """
    parent = os.path.dirname(path)
    if parent:
        # A bare filename has an empty dirname, and os.makedirs('') raises.
        os.makedirs(parent, exist_ok=True)
    with open(path, 'w', encoding='utf-8') as f:
        json.dump(data, f, indent=2)


def get_model(model_name: str, device: str):
    """
    Look up the model class registered under ``model_name`` and instantiate it.

    Args:
        model_name (str): One of the keys of the registry defined below.
        device (str): Passed to the model constructor as ``device``.

    Returns:
        The object produced by calling the selected class with ``device=device``.

    Raises:
        ValueError: If ``model_name`` is not a key of the registry.
    """
    registry = {
        'llama2-7b': Llama2_7b,
        'gemma-7b': Gemma7b,
        'vicuna-13b': Vicuna13b,
        'mistral7b': Mistral7b,
        'phi-mini': PhiMini,
        'llama3-8b': Llama3_8b,
        'qwen7b': Qwen7b,
        'r2d2': Zephyr7bR2D2,
        'mistral7b-rr': Mistral7B_RR,
        'llama3-8b-rr': Llama3_8bRR
    }

    # Every registered value is a class, so None unambiguously means "no such key".
    model_cls = registry.get(model_name)
    if model_cls is None:
        raise ValueError(f"Unknown model name: {model_name}")

    return model_cls(device=device)


def download_alpaca(path: str="./dataset/raw"):
    """
    Load the Alpaca dataset and keep the instructions that have no input.

    Nothing is written to disk here; the records are only returned.

    Args:
        path (str): Not used by this function.

    Returns:
        A list of ``{'instruction': <stripped text>, 'category': None}`` dicts,
        one per item of the ``train`` split whose ``input`` field is empty
        after stripping whitespace.
    """
    dataset = load_dataset('tatsu-lab/alpaca')

    json_data = []
    for item in dataset['train']:
        # Skip items that carry an additional input alongside the instruction.
        if item['input'].strip() != '':
            continue
        json_data.append({'instruction': item['instruction'].strip(), 'category': None})

    return json_data


def construct_harmless_dataset_splits(
    path: str = "./dataset/raw",
    train_p: float = 0.6,
    batch_size: int = 128
):
    """
    Build the Alpaca train split and write its first ``batch_size`` entries to disk.

    The instructions returned by ``download_alpaca`` are shuffled with a fixed
    seed, the first ``train_p`` fraction is taken as the train split, and the
    first ``batch_size`` records of that split are saved as
    ``<path>/alpaca_train.json``.

    Args:
        path (str): Output directory; created if it does not exist.
        train_p (float): Fraction of the shuffled instructions used as the train split.
        batch_size (int): Maximum number of records written to the JSON file.

    Note:
        This reseeds Python's global ``random`` generator with 42.
    """
    instructions = download_alpaca(path)

    # Fixed seed so the shuffle, and therefore the saved subset, is reproducible.
    random.seed(42)
    random.shuffle(instructions)

    train_end = int(train_p * len(instructions))

    splits = {
        'alpaca_train.json': instructions[:train_end],
    }

    # open() does not create missing directories, so make sure the target exists.
    # An empty path means "current directory", which needs no creation.
    if path:
        os.makedirs(path, exist_ok=True)

    for filename, subset in splits.items():
        out_path = os.path.join(path, filename)
        saved = subset[:batch_size]
        dump_json(saved, out_path)
        # Report what was actually written, not the size of the untruncated split.
        print(f"Saved {filename} ({len(saved)} samples) to {out_path}")



def convert_samples(
    model_name: str,
    device: str,
    batch_size: int = 128,
    save_dir: Optional[str] = './dataset/representations'
):
    """
    Convert prompts to hidden state representations and save them as a dataset.

    The train subset is rebuilt via ``construct_harmless_dataset_splits``,
    read back from ``./dataset/raw/alpaca_train.json``, and each instruction is
    passed to ``model.get_representations``. The results are stored in one
    tensor and saved to ``<save_dir>/<model_name>/HLx_train.pt``.

    Args:
        model_name (str): Name of the model to use.
        device (str): Device string, e.g. 'cuda:0' or 'cpu'.
        batch_size (int): Number of samples to process (default: 128).
        save_dir (str): Directory to save the representations dataset.
            ``None`` falls back to './dataset/representations'.

    Note:
        A sample that triggers a CUDA out-of-memory error is skipped and its
        row in the saved tensor stays all zeros; skipped indices are printed.
    """
    if save_dir is None:
        # The annotation allows None, but os.path.join(None, ...) would raise.
        save_dir = './dataset/representations'

    construct_harmless_dataset_splits(batch_size=batch_size)
    # Load the processed dataset
    data_path = os.path.join('./dataset/raw', 'alpaca_train.json')
    with open(data_path, 'r') as f:
        data = json.load(f)
    # The file may hold fewer than batch_size records, so size by what was read.
    actual_bs = len(data)

    print(f"Loaded {actual_bs} samples from {data_path}")

    model = get_model(model_name, device)
    hidden_states = torch.zeros(actual_bs, model.num_layer, model.hidden_dimension)
    skipped = []

    # Only forward activations are stored, so disable autograd to avoid keeping
    # a graph alive for every processed sample.
    # NOTE(review): assumes get_representations does not itself need gradients.
    with torch.no_grad():
        for i, prompt in enumerate(tqdm(data, desc=f"Processing {model_name}")):
            try:
                hidden_states[i] = model.get_representations(prompt=prompt['instruction'])
            except torch.cuda.OutOfMemoryError:
                print(f"CUDA out of memory for sample {i}, skipping.")
                skipped.append(i)
                torch.cuda.empty_cache()
                continue

    if skipped:
        print(
            f"Warning: {len(skipped)} of {actual_bs} samples were skipped; "
            f"their rows in the saved tensor are all zeros (indices: {skipped})."
        )

    out_dir = os.path.join(save_dir, model_name)
    os.makedirs(out_dir, exist_ok=True)
    out_path = os.path.join(out_dir, "HLx_train.pt")
    torch.save(hidden_states, out_path)
    print(f"Saved hidden states to {out_path}")

    # Drop the model and collect garbage first so empty_cache() can actually
    # return the freed blocks before the next model is loaded.
    del model
    gc.collect()
    torch.cuda.empty_cache()


if __name__ == "__main__":
    # Command-line entry point: parse options, then process each requested
    # model in turn with convert_samples.
    parser = argparse.ArgumentParser(
        description="Download data, build splits, and convert prompts to hidden states."
    )
    # The names listed in the help must match the keys accepted by get_model;
    # any other name makes get_model raise ValueError.
    parser.add_argument(
        '--model-names', '-m',
        nargs='+',
        required=True,
        help="Models to run (llama2-7b, gemma-7b, vicuna-13b, mistral7b, phi-mini, llama3-8b, qwen7b, r2d2, mistral7b-rr, llama3-8b-rr)."
    )
    # NOTE(review): the default selects CUDA device index 2 — confirm this is
    # intended for machines other than the author's.
    parser.add_argument(
        '--device', '-d',
        default='cuda:2',
        help="Device for inference (e.g. cuda:0 or cpu)."
    )
    parser.add_argument(
        '--batch-size', '-b',
        type=int,
        default=128,
        help="Number of samples to process."
    )
    parser.add_argument(
        '--save-dir', '-s',
        default='./dataset/representations',
        help="Directory to save hidden state files."
    )
    args = parser.parse_args()

    # Convert samples for each model
    for model_name in args.model_names:
        convert_samples(
            model_name=model_name,
            device=args.device,
            batch_size=args.batch_size,
            save_dir=args.save_dir
        )
